Tree-Based Methods

Module 12

Ray J. Hoobler

Libraries

Code
library(tidyverse)
library(ISLR2)
# install.packages("tree")
library(tree)

Basics of Decision Trees – Regression

Hitters Data Set from ISLR2

Code
hitters <- na.omit(Hitters) |> as_tibble()
hitters

Regression Using the Hitters Data

Code
# Fit a regression tree for log(Salary) using the number of years played and
# the number of hits from the previous year.
# Setting mincut = 50 requires at least 50 observations in each child node,
# which limits this tree to three terminal nodes.

hitters_tree_8_1 <- tree(log(Salary) ~ Years + Hits, data = hitters, mincut = 50)
plot(hitters_tree_8_1)
text(hitters_tree_8_1, digits = 3)

Plot of Data with Decision Tree Values

The red lines mark the splits from the decision tree.

Code
hitters |> 
  ggplot() +
  geom_point(aes(x = Years, y = Hits), color = "blue", position = position_jitter(width = 0.1, height = 0.1), shape = 21) +
  geom_segment(x = 4.5, xend = 25, y = 117.5, yend = 117.5, color = "red") +
  geom_vline(xintercept = 4.5, color = "red") +
  annotate("text", x = 24, y = 125, label = "117.5", color = "red") +
  annotate("text", x = 3.9, y = 0, label = "4.5", color = "red") +
  annotate("text", x = 2.5, y = 190, label = "R1", color = "black", size = 6) +
  annotate("text", x = 11.4, y = 18, label = "R2", color = "black", size = 6) +
  annotate("text", x = 16.3, y = 190, label = "R3", color = "black", size = 6) +
  theme_light()

How Are Splits Generated?

Goal is to find regions \(R_1, \ldots, R_J\) that minimize the residual sum of squares (RSS) within each region.

\[ \sum_{j=1}^{J} \sum_{i \in R_j} (y_i - \hat{y}_{R_j})^2 \]
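Because an exhaustive search over all possible region partitions is infeasible, the tree is grown by greedy recursive binary splitting: at each step, pick the predictor and cutpoint that most reduce the RSS. For a single predictor, one step of that search can be sketched in base R (a hand-rolled illustration on synthetic data; `tree()` implements the full recursive search internally):

```r
# Greedy search for the best single split on one predictor:
# scan candidate cutpoints and keep the one minimizing the summed RSS
# of the two resulting regions.
best_split <- function(x, y) {
  cuts <- sort(unique(x))
  cuts <- (head(cuts, -1) + tail(cuts, -1)) / 2  # midpoints between observed values
  rss <- sapply(cuts, function(s) {
    left  <- y[x <  s]
    right <- y[x >= s]
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  list(cut = cuts[which.min(rss)], rss = min(rss))
}

set.seed(1)
x <- runif(100, 0, 10)
y <- ifelse(x < 4, 1, 5) + rnorm(100, sd = 0.5)  # true break near x = 4
best_split(x, y)$cut  # recovers a cutpoint near 4
```

Repeating this search within each resulting region, over all predictors, grows the full tree.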

Regression Tree Analysis for the Hitters Data (ISLR2 Figure 8.4)

Nine (unnamed) features were included in the model shown in Figure 8.4; however, I couldn’t reproduce the same results, so I used 10 features that were not based on cumulative statistics.

Code
set.seed(1)

train <- sample(1:nrow(hitters), nrow(hitters)/2 + 1)
# hitters[train, ]
# hitters[-train,]

# AtBat + Hits + HmRun + Runs + RBI + Walks + Years + PutOuts + Assists + Errors

hitters_tree_8_4 <- tree(
  formula = log(Salary) ~ AtBat + Hits + HmRun + Runs + RBI + Walks + Years + PutOuts + Assists + Errors, 
  data = hitters, 
  control = tree.control(nobs = length(train), minsize = 2, mindev = 0.01),
  subset = train)

# par(pin = c(6, 4))  # Sets physical dimensions in inches (width, height)
plot(hitters_tree_8_4)
text(hitters_tree_8_4, pretty = 0, digits = 3, cex = 0.7, adj = c(0.5, 0.8))

Regression Tree Analysis for the Hitters Dataset (ISLR2 Figure 8.5)

Code
set.seed(1)
cv_hitters_tree_8_4 <- cv.tree(hitters_tree_8_4, FUN = prune.tree, K = 10)
plot(cv_hitters_tree_8_4$size, cv_hitters_tree_8_4$dev, type = "b")
Code
cv_hitters_tree_8_4
$size
 [1] 15 14 13 10  9  8  7  6  5  3  2  1

$dev
 [1]  59.06288  58.79853  58.79853  56.57428  56.46547  56.55295  60.15679
 [8]  61.31232  61.31232  63.74406  74.28324 108.79830

$k
 [1]      -Inf  1.231152  1.235242  1.380117  1.525860  1.668804  2.652881
 [8]  4.457421  5.020619  7.099941 11.126931 43.728005

$method
[1] "deviance"

attr(,"class")
[1] "prune"         "tree.sequence"

Pruning the Tree

Cost complexity pruning

\[ \sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m} (y_i - \hat{y}_{R_m})^2 + \alpha |T| \]
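To see how the penalty \(\alpha\) trades off fit against tree size, the objective can be evaluated by hand over a few candidate subtree sizes (toy numbers for illustration, loosely echoing the deviance sequence printed by cv.tree above; prune.tree() computes the real pruning sequence):

```r
# Cost-complexity objective for candidate subtrees: RSS(T) + alpha * |T|
# (toy values for illustration only)
subtrees <- data.frame(
  size = c(1, 2, 3, 5, 9),                 # number of terminal nodes |T|
  rss  = c(108.8, 74.3, 63.7, 61.3, 56.5)  # deviance at each size
)
penalized <- function(alpha) subtrees$rss + alpha * subtrees$size

# Larger alpha favors smaller trees:
subtrees$size[which.min(penalized(0))]   # alpha = 0: the largest tree wins
subtrees$size[which.min(penalized(5))]   # moderate penalty: a mid-size tree
subtrees$size[which.min(penalized(50))]  # heavy penalty: the root alone
```

Each value of \(\alpha\) selects one subtree from the nested pruning sequence; cross-validation then picks the \(\alpha\) (equivalently, the size) with the lowest estimated test error.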

Code
prune_cv_hitters_tree_8_4 <- prune.tree(hitters_tree_8_4, best = 9)
plot(prune_cv_hitters_tree_8_4)
text(prune_cv_hitters_tree_8_4, pretty = 0, digits = 3)

Summary of the Pruned Tree

Code
summary(prune_cv_hitters_tree_8_4)

Regression tree:
snip.tree(tree = hitters_tree_8_4, nodes = c(18L, 12L, 5L))
Variables actually used in tree construction:
[1] "Years" "Runs"  "Hits"  "AtBat" "HmRun"
Number of terminal nodes:  9 
Residual mean deviance:  0.1634 = 20.1 / 123 
Distribution of residuals:
     Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
-1.046000 -0.249100 -0.001706  0.000000  0.236500  1.161000 

Prediction From Tree-Based Regression

Code
salary_pred <- predict(prune_cv_hitters_tree_8_4, newdata = hitters[-train,])

tibble(
  salary_pred = salary_pred,
  actual_salary = hitters[-train,]$Salary
) |>
  ggplot() +
  geom_point(aes(x = actual_salary, y = exp(salary_pred)), color = "blue", position = position_jitter(width = 10, height = 10), shape = 21) +
  geom_abline(color = "red", linetype = "dashed") +
  labs(
    x = "Actual Salary",
    y = "Predicted Salary",
    title = "Predicted vs Actual Salaries"
  ) +
  theme_light()

Basics of Decision Trees – Classification

Classification Criteria

For regression, we used RSS as the splitting criterion. RSS is not defined for a qualitative response, so classification trees need a different measure.

Options:

Classification Error Rate: The fraction of training observations in a region that do not belong to the most common class.

\[ E = 1 - \max_k \hat{p}_{mk} \]

\(\hat{p}_{mk}\) is the proportion of observations in the \(m\)th region from the \(k\)th class.

Gini Index: A measure of total variance across the \(K\) classes.

\[ G = \sum_{k=1}^{K} \hat{p}_{mk} (1 - \hat{p}_{mk}) \]

The Gini index is small if all the \(\hat{p}_{mk}\) are close to 0 or 1, indicating a node with predominantly one class.

Entropy: A measure of disorder in a region.

\[ D = -\sum_{k=1}^{K} \hat{p}_{mk} \log \hat{p}_{mk} \]

Entropy will be near zero if the \(\hat{p}_{mk}\) are all near zero or one, again indicating a node with predominantly one class.
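The three criteria are easy to compare directly. A small helper (written here for illustration) evaluates each measure on a vector of class proportions for one node:

```r
# Node impurity measures for a vector of class proportions p (summing to 1)
class_error <- function(p) 1 - max(p)
gini        <- function(p) sum(p * (1 - p))
entropy     <- function(p) -sum(ifelse(p > 0, p * log(p), 0))  # 0 * log(0) := 0

p_pure  <- c(0.95, 0.05)  # nearly pure node
p_mixed <- c(0.50, 0.50)  # maximally mixed node

sapply(list(pure = p_pure, mixed = p_mixed),
       function(p) c(error = class_error(p), gini = gini(p), entropy = entropy(p)))
```

All three measures are minimized at pure nodes; Gini and entropy are differentiable and more sensitive to changes in node purity, which is why they are typically preferred when growing the tree.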

Classification Using the Heart Data

UCI Machine Learning Repository: Heart Disease

Attribute information (only 14 attributes are used):

  1. #3 (age)
  2. #4 (sex)
  3. #9 (cp)
  4. #10 (trestbps)
  5. #12 (chol)
  6. #16 (fbs)
  7. #19 (restecg)
  8. #32 (thalach)
  9. #38 (exang)
  10. #40 (oldpeak)
  11. #41 (slope)
  12. #44 (ca)
  13. #51 (thal)
  14. #58 (num) (the predicted attribute)

Classification Using the Heart Data (cont.)

3 age: age in years

4 sex: sex (1 = male; 0 = female)

9 cp: chest pain type
– Value 1: typical angina
– Value 2: atypical angina
– Value 3: non-anginal pain
– Value 4: asymptomatic

10 trestbps: resting blood pressure (in mm Hg on admission to the hospital)

12 chol: serum cholesterol in mg/dl

16 fbs: fasting blood sugar > 120 mg/dl (1 = true; 0 = false)

19 restecg: resting electrocardiographic results
– Value 0: normal
– Value 1: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV)
– Value 2: showing probable or definite left ventricular hypertrophy by Estes’ criteria

32 thalach: maximum heart rate achieved

38 exang: exercise induced angina (1 = yes; 0 = no)

40 oldpeak: ST depression induced by exercise relative to rest

41 slope: the slope of the peak exercise ST segment

44 ca: number of major vessels (0–3) colored by fluoroscopy

51 thal: 3 = normal; 6 = fixed defect; 7 = reversible defect

58 num: diagnosis of heart disease (angiographic disease status)

  • Value 0: < 50% diameter narrowing
  • Value 1: > 50% diameter narrowing

(in any major vessel: attributes 59 through 68 are vessels)

Classification Using the Heart Data (cont.)

Inspect the data

Code
# Using the abbreviations above, the column names are:
heart_col_names <- c("age", "sex", "cp", "trestbps", "chol", "fbs", "restecg", "thalach", "exang", "oldpeak", "slope", "ca", "thal", "num")

heart <- read_csv("datasets/heart+disease/processed.cleveland.data", col_names = heart_col_names, na = "?")
heart
Code
heart |> count(num)

Classification Using the Heart Data (cont.)

From the article, describing the num variable:

The fluoroscopic data consisted of the number of major vessels that appeared to contain calcium.

Set num as a binary factor

Code
heart_clean <- heart |>
  mutate(num_binary = if_else(num == 0, 0, 1)) |>
  select(-num) |> 
  mutate(
    num_binary = factor(num_binary, levels = c(0, 1), labels = c("no", "yes")),
    sex = factor(sex, levels = c(0, 1), labels = c("female", "male")),
    cp = as.factor(cp),
    fbs = as.factor(fbs),
    restecg = as.factor(restecg),
    exang = as.factor(exang),
    slope = as.factor(slope),
    ca = as.numeric(ca),
    thal = as.factor(thal)
    )

heart_clean

Classification Tree for the Heart Data

Code
set.seed(1)
train <- sample(1:nrow(heart_clean), nrow(heart_clean)/2 + 1)

heart_tree <- tree(num_binary ~ ., data = heart_clean, subset = train)
# summary(heart_tree)
plot(heart_tree)
text(heart_tree, pretty = 0, digits = 3)

Cross-Validation of the Heart Data Classification Model

Code
set.seed(1)
cv_heart_tree <- cv.tree(heart_tree, FUN = prune.tree, K = 10)
plot(cv_heart_tree$size, cv_heart_tree$dev, type = "b")

Plot of the Pruned Tree

Code
prune_heart_tree <- prune.tree(heart_tree, best = 6)
plot(prune_heart_tree)
text(prune_heart_tree, pretty = 0, digits = 3)

Prediction From Tree-Based Classification

                         True Negative (N)   True Positive (P)   Total
Predicted Negative (N)          TN                  FN             N*
Predicted Positive (P)          FP                  TP             P*
Total                           N                   P            Total
Code
heart_pred <- predict(prune_heart_tree, newdata = heart_clean[-train,], type = "class")
table(heart_pred, heart_clean[-train,]$num_binary)
          
heart_pred no yes
       no  51  13
       yes 27  60
Code
# Accuracy: (TP + TN) / (TP + TN + FP + FN)
paste("Accuracy ( (TP + TN) / (TP + TN + FP + FN) ):", round(mean(heart_pred == heart_clean[-train,]$num_binary), 3))
[1] "Accuracy ( (TP + TN) / (TP + TN + FP + FN) ): 0.735"
Code
# Sensitivity, Recall (1 - Type II Error)
paste("Sensitivity, Recall, Power ( TP/P ):", round(sum(heart_pred == "yes" & heart_clean[-train,]$num_binary == "yes") / sum(heart_clean[-train,]$num_binary == "yes"), 3))
[1] "Sensitivity, Recall, Power ( TP/P ): 0.822"
Code
# Specificity: TN/N
paste("Specificity ( TN/N ):", round(sum(heart_pred == "no" & heart_clean[-train,]$num_binary == "no") / sum(heart_clean[-train,]$num_binary == "no"), 3))
[1] "Specificity ( TN/N ): 0.654"
Code
# Precision: TP/P*
paste("Precision ( TP/P* ):", round(sum(heart_pred == "yes" & heart_clean[-train,]$num_binary == "yes") / sum(heart_pred == "yes"), 3))
[1] "Precision ( TP/P* ): 0.69"

Advantages and Disadvantages of Trees (ISLR2 authors)

Advantages

  • Easy to explain
  • Mirror human decision-making
  • Can be displayed graphically
  • Can handle qualitative predictors without the need for dummy variables

Disadvantages

  • Poor predictive accuracy
  • “Non-robust” (Small changes in the data can lead to large changes in the final estimated tree.)

Bagging and Random Forests

Bagging

Bootstrap aggregating

Description
Bootstrap aggregating (bagging) is a general-purpose procedure for reducing the variance of a statistical learning method.

The basic idea is to average multiple models to reduce the variance of the model.

Here, the bootstrap method involves repeatedly sampling observations with replacement from the training data set, fitting a model to each sample, and then combining the models to create a single predictive model.

Bagging Summary

To apply bagging to regression trees, we simply construct B regression trees using B bootstrapped training sets, and average the resulting predictions. These trees are grown deep, and are not pruned. Hence each individual tree has high variance, but low bias. Averaging these B trees reduces the variance. Bagging has been demonstrated to give impressive improvements in accuracy by combining together hundreds or even thousands of trees into a single procedure.

ISLR2, p 341
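The averaging idea can be sketched without any tree machinery: fit the same high-variance model to B bootstrap resamples and average the predictions. Here a flexible polynomial fit on synthetic data stands in for a deep tree (illustration only, not the example that follows):

```r
# Bagging sketch: average predictions from B models fit on bootstrap resamples
set.seed(1)
n <- 100
x <- runif(n, -2, 2)
y <- sin(2 * x) + rnorm(n, sd = 0.3)
dat <- data.frame(x = x, y = y)

B <- 200
grid <- data.frame(x = seq(-2, 2, length.out = 50))
preds <- sapply(1:B, function(b) {
  idx <- sample(n, replace = TRUE)              # bootstrap sample of the rows
  fit <- lm(y ~ poly(x, 5), data = dat[idx, ])  # a flexible, high-variance fit
  predict(fit, newdata = grid)
})
bagged <- rowMeans(preds)  # the bagged prediction at each grid point
head(bagged)
```

Each column of `preds` is one bootstrap model's prediction curve; averaging across columns smooths out the sample-to-sample wiggle, which is exactly the variance reduction bagging buys.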

Bagging Example

Code
library(ranger)
library(caret)

train_index <- createDataPartition(heart_clean$num_binary, p = 0.7, list = FALSE)
train_data <- heart_clean[train_index, ]
test_data <- heart_clean[-train_index, ]

# Bagging: a random forest with mtry = p, so all 13 predictors are
# considered at every split. probability = FALSE gives class predictions.
rf_model <- ranger(
  num_binary ~ .,           # Formula: predict target using all other variables
  data = train_data,        # Training data
  num.trees = 500,          # Number of trees
  mtry = 13,                # Use all variables for each split 
  importance = 'impurity',  # Calculate variable importance (Gini index for classification)
  probability = FALSE       # Get class predictions instead of probabilities
)

# Make predictions on test set - directly get class predictions
predictions <- predict(rf_model, test_data)
pred_class <- predictions$predictions  # Class predictions

# Calculate accuracy
accuracy <- mean(pred_class == test_data$num_binary)

# Get variable importance
var_importance <- data.frame(
  Feature = names(importance(rf_model)),
  Importance = importance(rf_model)
)
# var_importance <- var_importance[order(var_importance$Importance, decreasing = TRUE), ]
var_importance <- as_tibble(var_importance) |> 
  arrange(desc(Importance))

var_importance

Evaluation of Bagging Model

Code
# Print results
print("Model Performance:")
[1] "Model Performance:"
Code
print(paste("Accuracy:", round(accuracy, 3)))
[1] "Accuracy: 0.811"
Code
print(paste("Sensitivity:", round(sum(pred_class == "yes" & test_data$num_binary == "yes") / sum(test_data$num_binary == "yes"), 3)))
[1] "Sensitivity: 0.78"
Code
print(paste("Precision:", round(sum(pred_class == "yes" & test_data$num_binary == "yes") / sum(pred_class == "yes"), 3)))
[1] "Precision: 0.8"
Code
print("Confusion Matrix:")
[1] "Confusion Matrix:"
Code
print(table(Predicted = pred_class, Actual = test_data$num_binary))
         Actual
Predicted no yes
      no  41   9
      yes  8  32

Variable Importance Plot

Code
var_importance |> 
  ggplot(aes(y = reorder(Feature, Importance), x = Importance)) +
  geom_col(fill = "skyblue") +
  labs(
    title = "Variable Importance Plot for Heart Data",
    subtitle = "Mean decrease in Gini index per variable",
    x = "Variable Importance",
    y = NULL
  ) +
  theme_light() +
  theme(
    plot.title.position = "plot"
  )

Out-of-Bag Error Estimation

It turns out that there is a very straightforward way to estimate the test error of a bagged model, without the need to perform cross-validation or the validation set approach. Recall that the key to bagging is that trees are repeatedly fit to bootstrapped subsets of the observations. One can show that on average, each bagged tree makes use of around two-thirds of the observations. The remaining one-third of the observations not used to fit a given bagged tree are referred to as the out-of-bag (OOB) observations.

. . .

The resulting OOB error is a valid estimate of the test error for the bagged model, since the response for each observation is predicted using only the trees that were not fit using that observation.

ISLR2, p 342
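The “about two-thirds” figure is easy to verify empirically: the probability an observation appears at least once in a bootstrap sample of size n is \(1 - (1 - 1/n)^n \approx 1 - e^{-1} \approx 0.632\). A quick check, independent of ranger:

```r
# Fraction of distinct observations captured by a bootstrap sample of size n
set.seed(1)
n <- 300
in_bag_frac <- replicate(2000, {
  length(unique(sample(n, replace = TRUE))) / n
})
mean(in_bag_frac)  # ~ 0.632, so ~ 36.8% of observations are OOB for each tree
1 - (1 - 1/n)^n    # the exact expectation, ~ 1 - exp(-1)
```

Those roughly 37% left-out observations are the OOB set for that tree, which is why ranger can report `prediction.error` without a separate validation split.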

Out-of-Bag Error Plot and Results

Code
# To track OOB error by number of trees, we can create multiple models
tree_numbers <- 1:500
oob_errors <- numeric(length(tree_numbers))

for(i in seq_along(tree_numbers)) {
  rf_temp <- ranger(
    num_binary ~ .,
    data = heart_clean,
    num.trees = tree_numbers[i],
    importance = 'impurity',
    probability = FALSE,
    oob.error = TRUE
  )
  oob_errors[i] <- rf_temp$prediction.error
}

# Create data frame of results
oob_results <- data.frame(
  Trees = tree_numbers,
  OOB_Error = oob_errors
)

# Print OOB error progression
# print("OOB Error by Number of Trees:")
# print(oob_results)

oob_results |> 
  ggplot(aes(x = Trees, y = OOB_Error)) +
  geom_line() +
  theme_light()

Code
# Additional model information
print("Model Information:")
[1] "Model Information:"
Code
print(paste("Number of trees:", rf_model$num.trees))
[1] "Number of trees: 500"
Code
print(paste("Number of independent variables:", rf_model$num.independent.variables))
[1] "Number of independent variables: 13"
Code
print(paste("Mtry:", rf_model$mtry))
[1] "Mtry: 13"

Random Forest Example

Code
library(ranger)
library(caret)

set.seed(123)
train_index <- createDataPartition(heart_clean$num_binary, p = 0.7, list = FALSE)
train_data <- heart_clean[train_index, ]
test_data <- heart_clean[-train_index, ]

# Train random forest model - note probability = FALSE for class predictions
rf_model <- ranger(
  num_binary ~ .,           # Formula: predict target using all other variables
  data = train_data,        # Training data
  num.trees = 500,          # Number of trees
  mtry = floor(sqrt(13)),   # Use sqrt(p) variables for each split
  importance = 'impurity',  # Calculate variable importance (Gini index for classification)
  probability = FALSE       # Get class predictions instead of probabilities
)

# Make predictions on test set - directly get class predictions
predictions <- predict(rf_model, test_data)
pred_class <- predictions$predictions  # Class predictions

# Calculate accuracy
accuracy <- mean(pred_class == test_data$num_binary)

# Calculate precision
precision <- sum(pred_class == "yes" & test_data$num_binary == "yes") / sum(pred_class == "yes")

# Calculate sensitivity
sensitivity <- sum(pred_class == "yes" & test_data$num_binary == "yes") / sum(test_data$num_binary == "yes")

print("Model Performance:")
[1] "Model Performance:"
Code
print(paste("Accuracy:", round(accuracy, 3)))
[1] "Accuracy: 0.822"
Code
print(paste("Sensitivity:", round(sensitivity, 3)))
[1] "Sensitivity: 0.756"
Code
print(paste("Precision:", round(precision, 3)))
[1] "Precision: 0.838"

Homework

End of Module 12